import pandas as pd
import os
import sys
import json
import collections
import random
import math
import argparse
import time
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from utils import print_local_time
from model_base import BubbleEmbedBase
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

class DataTrain(Dataset):
    """Training split of question pairs read from a CSV file.

    Each item is ``(encoding_of_question1, encoding_of_question2, label)``:
    the encodings are the tokenizer outputs (padding="max_length", truncation,
    max_length=40) with ``input_ids``, ``attention_mask`` and ``token_type_ids``
    stripped of their leading batch dimension, and ``label`` is the row's
    ``is_duplicate`` value as a float tensor. Tensors are moved to the GPU when
    ``args.cuda`` is set.
    """

    def __init__(self, args, tokenizer, data_path):
        self.args = args
        self.data = self._load_data(data_path)
        self.tokenizer = tokenizer

    def _load_data(self, data_path):
        """Read the CSV at ``data_path``; its first row is the header."""
        return pd.read_csv(data_path, header=0, index_col=None)

    def __len__(self):
        return len(self.data)

    def process_encode(self, encode):
        """Squeeze dim 0 of the three tokenizer fields and optionally move them to GPU.

        The processed tensors are written back into ``encode``, which is returned.
        """
        fields = ("input_ids", "attention_mask", "token_type_ids")
        # Read every field before writing any back, so a missing key leaves
        # ``encode`` untouched.
        squeezed = {name: encode[name].squeeze(0) for name in fields}
        for name in fields:
            tensor = squeezed[name]
            encode[name] = tensor.cuda() if self.args.cuda else tensor
        return encode

    def _encode(self, text):
        """Tokenize one question to a fixed length and post-process it."""
        raw = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            return_tensors='pt',
            max_length=40,
        )
        return self.process_encode(raw)

    def generate_train_instance_id(self, idx):
        """Return ``(encoding_A, encoding_B, is_duplicate)`` for row ``idx``."""
        row = self.data.iloc[idx]
        first = self._encode(row['question1'])
        second = self._encode(row['question2'])
        return first, second, row['is_duplicate']

    def __getitem__(self, idx):
        encode_A, encode_B, prob = self.generate_train_instance_id(idx)
        label = torch.tensor(prob, dtype=torch.float)
        if self.args.cuda:
            label = label.cuda()
        return encode_A, encode_B, label
    
class DataTest(Dataset):
    """Test split of question pairs read from a CSV file.

    Each item is ``(encoding_of_question1, encoding_of_question2, label, qid1, qid2)``.
    The encodings and label are produced as for training: tokenizer output
    (padding="max_length", truncation, max_length=40) with dim 0 squeezed, and
    ``is_duplicate`` as a float tensor, all moved to the GPU when ``args.cuda``
    is set. ``qid1``/``qid2`` are returned as read from the CSV.
    """

    def __init__(self, args, tokenizer, data_path):
        self.args = args
        self.data = self._load_data(data_path)
        self.tokenizer = tokenizer

    def _load_data(self, data_path):
        """Read the CSV at ``data_path``; its first row is the header."""
        return pd.read_csv(data_path, header=0, index_col=None)

    def __len__(self):
        return len(self.data)

    def process_encode(self, encode):
        """Squeeze dim 0 of the three tokenizer fields and optionally move them to GPU.

        The processed tensors are written back into ``encode``, which is returned.
        """
        fields = ("input_ids", "attention_mask", "token_type_ids")
        # Read every field before writing any back, so a missing key leaves
        # ``encode`` untouched.
        squeezed = {name: encode[name].squeeze(0) for name in fields}
        for name in fields:
            tensor = squeezed[name]
            encode[name] = tensor.cuda() if self.args.cuda else tensor
        return encode

    def _encode(self, text):
        """Tokenize one question to a fixed length and post-process it."""
        raw = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            return_tensors='pt',
            max_length=40,
        )
        return self.process_encode(raw)

    def generate_test_instance_id(self, idx):
        """Return ``(encoding_A, encoding_B, is_duplicate, qid1, qid2)`` for row ``idx``."""
        row = self.data.iloc[idx]
        first = self._encode(row['question1'])
        second = self._encode(row['question2'])
        return first, second, row['is_duplicate'], row["qid1"], row["qid2"]

    def __getitem__(self, idx):
        encode_A, encode_B, prob, q1, q2 = self.generate_test_instance_id(idx)
        label = torch.tensor(prob, dtype=torch.float)
        if self.args.cuda:
            label = label.cuda()
        return encode_A, encode_B, label, q1, q2

class LabelClassfnExp(object):
    """Train and evaluate a bubble-embedding model on question pairs.

    ``BubbleEmbedBase`` maps each tokenized question to a bubble, i.e. a center
    vector and a radius. Training minimizes a weighted sum of a conditional
    probability loss, a minimum-radius regularizer and a containment loss.
    Evaluation predicts ``is_duplicate`` by thresholding the percentage overlap
    of the two bubbles.
    """

    def __init__(self, args):
        """Build tokenizer, data loaders, model and optimizers from ``args``."""
        self.args = args
        self.tokenizer = self.__load_tokenizer__()
        self.train_loader, self.train_set = self.load_data(self.args, self.tokenizer, "train")
        self.test_loader, self.test_set = self.load_data(self.args, self.tokenizer, "test")
        self.model = BubbleEmbedBase(args)
        self.optimizer_pretrain, self.optimizer_projection = self._select_optimizer()
        self._set_device()
        self._set_seed(self.args.seed)
        self.setting = self.args
        # Run identifier embedded in every checkpoint file name.
        self.exp_setting = "_".join(
            str(value)
            for value in (
                self.args.dataset,
                self.args.expID,
                self.args.epochs,
                self.args.embed_size,
                self.args.batch_size,
                self.args.lr,
                self.args.phi,
                self.args.regularwt,
                self.args.probwt,
                self.args.seed,
                self.args.version,
            )
        )

        # Loss functions
        self.regular_loss = nn.MSELoss()
        self.contain_loss = nn.MSELoss()
        self.prob_loss = nn.BCELoss()
        self.bubble_size_loss = nn.MSELoss()

        # Volume of the unit ball in `embed_size` dimensions: pi^(n/2) / Gamma(n/2 + 1).
        self.num_dimensions = self.args.embed_size
        self.volume_factor = (math.pi ** (args.embed_size / 2)) / math.gamma((args.embed_size / 2) + 1)

    def load_data(self, args, tokenizer, mode):
        """Build the DataLoader and Dataset for one split.

        Args:
            args: experiment arguments (reads ``dataset`` and ``batch_size``).
            tokenizer: tokenizer handed to the dataset.
            mode: ``"train"`` (shuffled ``DataTrain``) or ``"test"`` (unshuffled ``DataTest``).

        Returns:
            ``(dataloader, dataset)``.

        Raises:
            ValueError: if ``mode`` is neither ``"train"`` nor ``"test"``.
        """
        data_dir = f"../data/{args.dataset}/processed"
        if mode == "train":
            shuffle_flag = True
            dataset = DataTrain(args, tokenizer, os.path.join(data_dir, "train_questions.csv"))
        elif mode == "test":
            shuffle_flag = False
            dataset = DataTest(args, tokenizer, os.path.join(data_dir, "test_questions.csv"))
        else:
            # Previously fell through to an UnboundLocalError on `dataset`.
            raise ValueError(f"Unknown data mode {mode!r}; expected 'train' or 'test'.")

        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle_flag)
        return dataloader, dataset

    def __load_tokenizer__(self):
        """Load the pretrained ``bert-base-uncased`` tokenizer."""
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        print("Tokenizer Loaded!")
        return tokenizer

    def _select_optimizer(self):
        """Create separate optimizers for the encoder and the projection head.

        Parameters whose names start with ``pre_train`` use ``args.lr``; those
        starting with ``projection`` use ``args.lr_projection``. Both groups use
        zero weight decay.

        Raises:
            ValueError: if ``args.optim`` is neither ``"adam"`` nor ``"adamw"``.
        """

        def _group(prefix):
            # One param group holding every parameter whose name starts with `prefix`.
            return [
                {
                    "params": [p for n, p in self.model.named_parameters() if n.startswith(prefix)],
                    "weight_decay": 0.0,
                },
            ]

        pre_train_parameters = _group("pre_train")
        projection_parameters = _group("projection")

        if self.args.optim == "adam":
            optimizer_pretrain = optim.Adam(pre_train_parameters, lr=self.args.lr)
            optimizer_projection = optim.Adam(projection_parameters, lr=self.args.lr_projection)
        elif self.args.optim == "adamw":
            optimizer_pretrain = optim.AdamW(pre_train_parameters, lr=self.args.lr, eps=self.args.eps)
            optimizer_projection = optim.AdamW(
                projection_parameters, lr=self.args.lr_projection, eps=self.args.eps
            )
        else:
            # Previously fell through to an UnboundLocalError on the return.
            raise ValueError(f"Unknown optimizer {self.args.optim!r}; expected 'adam' or 'adamw'.")

        return optimizer_pretrain, optimizer_projection

    def _set_device(self):
        """Move the model to the GPU when ``args.cuda`` is set."""
        if self.args.cuda:
            self.model = self.model.cuda()

    def _map_location(self):
        """Device that loaded checkpoints are mapped onto (current CUDA device or CPU)."""
        return "cuda" if self.args.cuda else "cpu"

    @staticmethod
    def _match_batch_shape(value, reference):
        """Append trailing singleton dims to ``value`` until it has ``reference``'s rank.

        Radii can carry a trailing singleton dimension (``(batch, 1)``) while center
        distances and targets are ``(batch,)``. Without aligning them, broadcasting
        pairs every radius with every distance and yields a ``(batch, batch)`` tensor.
        """
        while value.ndim < reference.ndim:
            value = value.unsqueeze(-1)
        return value

    def center_distance(self, center1, center2):
        """Euclidean (L2) distance between centers along the last dimension."""
        return torch.linalg.norm(center1 - center2, 2, -1)

    def bubble_volume(self, delta, temperature=0.1):
        """Volume of an ``embed_size``-dimensional ball of radius ``delta``.

        Entries with ``delta <= 0`` get volume 0. ``temperature`` is accepted but
        not used.
        """
        valid_mask = (delta > 0).float()
        volume = self.volume_factor * torch.pow(delta, self.num_dimensions)
        return volume * valid_mask

    def bubble_regularization(self, delta):
        """Penalize radii smaller than the minimum radius ``args.phi``.

        The loss is the MSE between the masked radii and ``phi``, where only
        radii below ``phi`` are kept and the rest are zero. It is averaged over
        every element of ``delta``.
        """
        small_bubble_mask = (delta < self.args.phi).to(delta.dtype)
        min_radius = torch.ones_like(delta) * self.args.phi
        return self.bubble_size_loss(
            torch.mul(delta, small_bubble_mask),
            torch.mul(min_radius, small_bubble_mask),
        )

    def containment_loss_cached(self, delta1, delta2, dist_center, tmask):
        """Penalize pairs with ``tmask == 1`` where bubble 1 does not contain bubble 2.

        Containment holds when ``dist_center <= delta1 - delta2``. A negative
        ``(delta1 - delta2) - dist_center`` is a violation, which is pushed
        toward 0 with an MSE loss.
        """
        dist_center = self._match_batch_shape(dist_center, delta1)
        tmask = self._match_batch_shape(tmask, delta1)
        violation = (delta1 - delta2) - dist_center
        # Keep only violating pairs whose target mask is set.
        mask = (violation < 0).float() * tmask
        return self.contain_loss(violation * mask, torch.zeros_like(violation))

    def disjoint_loss_cached(self, delta1, delta2, dist_center, inv_tmask):
        """Penalize overlap of bubbles for pairs with ``inv_tmask == 0``.

        A positive ``delta1 + delta2 - dist_center`` means the bubbles overlap;
        that overlap is pushed toward 0 with an MSE loss.
        """
        dist_center = self._match_batch_shape(dist_center, delta1)
        inv_tmask = self._match_batch_shape(inv_tmask, delta1)
        diff = delta1 + delta2 - dist_center
        mask = (diff > 0).float() * (1 - inv_tmask)
        return self.contain_loss(diff * mask, torch.zeros_like(diff))

    def radial_intersection_cached(self, delta1, delta2, dist_center):
        """Half the overlap depth of two bubbles, capped by the smaller radius.

        Returns 0 where the bubbles do not overlap (``dist_center >= delta1 + delta2``).
        """
        sum_radius = delta1 + delta2
        if dist_center.ndim == 1:
            dist_center = dist_center.unsqueeze(1)
        mask = (dist_center < sum_radius).float()
        intersection_radius = mask * ((sum_radius - dist_center) / 2)
        return torch.min(intersection_radius, torch.min(delta1, delta2))

    def condition_score_cached(self, radius_A, radius_B, dist_center):
        """Intersection radius divided by ``radius_B``, i.e. a score conditioned on B."""
        inter_delta = self.radial_intersection_cached(radius_A, radius_B, dist_center)
        masked_inter_delta = inter_delta * (inter_delta > 0).float()
        scores = masked_inter_delta / radius_B
        # Drop only the trailing singleton dim: a bare `.squeeze()` turns a batch of
        # one into a 0-d tensor, which no longer matches the (1,) targets in BCELoss.
        if scores.ndim > 1:
            scores = scores.squeeze(-1)
        return scores

    def cond_prob_loss_cached(self, radius_A, radius_B, dist_center, targets):
        """BCE between the conditional score (clamped away from 0/1) and ``targets``."""
        score = self.condition_score_cached(radius_A, radius_B, dist_center)
        score = score.clamp(1e-7, 1 - 1e-7)
        return self.prob_loss(score, targets)

    def compute_loss(self, encode_A, encode_B, targets):
        """Weighted training loss for one batch of question pairs.

        ``probwt * (P(A|B) + P(B|A) BCE) + regularwt * (radius regularization of
        A and B) + containwt * (containment of A in B and of B in A)``.
        """
        center_A, radius_A = self.model(encode_A)
        center_B, radius_B = self.model(encode_B)
        c_dist = self.center_distance(center_A, center_B)

        regular_loss = self.bubble_regularization(radius_B) + self.bubble_regularization(radius_A)

        cond_prob_loss = self.cond_prob_loss_cached(
            radius_A, radius_B, c_dist, targets
        ) + self.cond_prob_loss_cached(radius_B, radius_A, c_dist, targets)

        containment_loss = self.containment_loss_cached(
            radius_A, radius_B, c_dist, targets
        ) + self.containment_loss_cached(radius_B, radius_A, c_dist, targets)

        # NOTE(review): the original also computed disjoint_loss_cached here but never
        # added it to the objective; it is omitted rather than computed and discarded.
        # Add it with its own weight if it is meant to be trained on.
        return (
            self.args.probwt * cond_prob_loss
            + self.args.regularwt * regular_loss
            + self.args.containwt * containment_loss
        )

    def _set_seed(self, seed):
        """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if self.args.cuda:
            torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    def train_one_step(self, it, encode_A, encode_B, targets):
        """Run one optimization step on a batch and return its loss tensor."""
        self.model.train()
        self.optimizer_pretrain.zero_grad()
        self.optimizer_projection.zero_grad()

        if not isinstance(targets, torch.Tensor):
            targets = torch.tensor(targets, dtype=torch.float)
            if self.args.cuda:
                targets = targets.cuda()

        loss = self.compute_loss(encode_A, encode_B, targets)
        loss.backward()
        self.optimizer_pretrain.step()
        self.optimizer_projection.step()
        return loss

    def train(self, checkpoint=None, save_path=None):
        """Train for ``args.epochs`` epochs, evaluating on the test split after each.

        Args:
            checkpoint: optional state-dict file to resume from.
            save_path: directory for the best-model checkpoint. Defaults to
                ``../result/<dataset>/model``.

        The best checkpoint is written to
        ``<save_path>/exp_model_<exp_setting>.checkpoint``. Only the latest
        per-epoch checkpoint is kept in ``../result/<dataset>/train``.
        """
        self._set_seed(self.args.seed)
        time_tracker = []
        best_prec = 0
        best_recall = 0
        best_f1 = 0

        if checkpoint:
            self.model.load_state_dict(torch.load(checkpoint, map_location=self._map_location()))

        result_dir = os.path.join("../result", self.args.dataset)
        train_path = os.path.join(result_dir, "train")
        if save_path is None:
            save_path = os.path.join(result_dir, "model")
        # Per-epoch checkpoints always go to train_path, so create it even when the
        # caller supplies save_path (it used to be created only in the default case).
        os.makedirs(save_path, exist_ok=True)
        os.makedirs(train_path, exist_ok=True)

        for epoch in tqdm(range(self.args.epochs)):
            train_loss = []

            epoch_time = time.time()
            print(f"Epoch {epoch+1}/{self.args.epochs}")
            for it, (encode_A, encode_B, targets) in tqdm(enumerate(self.train_loader), total=len(self.train_loader)):
                loss = self.train_one_step(it, encode_A, encode_B, targets)
                train_loss.append(loss.item())

            train_loss = np.average(train_loss)
            test_metrics = self.predict()

            # Keep the model when precision does not drop and recall or F1 does not drop.
            if test_metrics["Precision"] >= best_prec and (
                test_metrics["Recall"] >= best_recall or test_metrics["F1"] >= best_f1
            ):
                best_prec = test_metrics["Precision"]
                best_recall = test_metrics["Recall"]
                best_f1 = test_metrics["F1"]
                torch.save(
                    self.model.state_dict(),
                    os.path.join(save_path, f"exp_model_{self.exp_setting}.checkpoint"),
                )

            time_tracker.append(time.time() - epoch_time)
            print(
                "Epoch: {:04d}".format(epoch + 1),
                " train_loss:{:.05f}".format(train_loss),
                " Precision:{:.05f}".format(test_metrics["Precision"]),
                " Recall:{:.05f}".format(test_metrics["Recall"]),
                " F1:{:.05f}".format(test_metrics["F1"]),
                " epoch_time:{:.01f}s".format(time.time() - epoch_time),
                " remain_time:{:.01f}s".format(np.mean(time_tracker) * (self.args.epochs - (1 + epoch))),
            )

            # Keep only the most recent per-epoch checkpoint.
            torch.save(
                self.model.state_dict(),
                os.path.join(train_path, f"exp_model_{self.exp_setting}_{epoch}.checkpoint"),
            )
            if epoch:
                previous = os.path.join(train_path, f"exp_model_{self.exp_setting}_{epoch - 1}.checkpoint")
                if os.path.exists(previous):
                    os.remove(previous)

    def metrics(self, pred, target, key=None):
        """Score predictions against targets (rounded to 0/1 integers).

        Returns precision, recall and F1 (``zero_division=0``) when ``key`` is
        falsy. Otherwise returns only ``{"ROC_AUC": ...}``.
        """
        pred = np.array(pred)
        # ``pred`` is used as-is: thresholded 0/1 labels for P/R/F1, or scores for ROC AUC.
        target = np.round(np.array(target)).astype(int)

        if key:
            return {"ROC_AUC": roc_auc_score(target, pred)}

        return {
            "Precision": precision_score(target, pred, zero_division=0),
            "Recall": recall_score(target, pred, zero_division=0),
            "F1": f1_score(target, pred, zero_division=0),
        }

    def percent_overlap(self, rad1, rad2, cdist, threshold=None):
        """Overlap ``1 - cdist / (rad1 + rad2)`` clipped to [0, 1], as a NumPy array.

        With ``threshold``, it returns 1 where the overlap exceeds it and 0
        elsewhere. ``(batch, 1)`` radii are squeezed to ``(batch,)`` first.
        """
        if rad1.ndim == 2:
            rad1 = rad1.squeeze(1)
        if rad2.ndim == 2:
            rad2 = rad2.squeeze(1)
        overlap = 1 - (cdist / (rad1 + rad2))
        overlap = torch.clamp(overlap, min=0, max=1)
        if threshold is not None:
            overlap = torch.where(overlap > threshold, 1, 0)
        return overlap.cpu().numpy()

    def predict(self, tag=None, load_model_path=None):
        """Evaluate on the test split and return precision/recall/F1.

        Args:
            tag: when ``"test"``, this method first loads the saved best
                checkpoint (or ``load_model_path``). It then prints the metrics
                and appends them, with the run arguments, to
                ``../result/<dataset>/res_<version>.json``.
            load_model_path: checkpoint to load instead of the default best-model path.

        Returns:
            dict with ``"Precision"``, ``"Recall"`` and ``"F1"``.
        """
        print("Predicting...")
        if tag == "test":
            model_path = (
                load_model_path
                if load_model_path
                else f"../result/{self.args.dataset}/model/exp_model_{self.exp_setting}.checkpoint"
            )
            self.model.load_state_dict(
                torch.load(model_path, weights_only=True, map_location=self._map_location())
            )

        self.model.eval()
        ground_truth = []
        prediction = []
        with torch.no_grad():
            for it, (encode_A, encode_B, targets, qidA, qidB) in tqdm(enumerate(self.test_loader), total=len(self.test_loader)):
                center_A, radius_A = self.model(encode_A)
                center_B, radius_B = self.model(encode_B)
                c_dist = self.center_distance(center_A, center_B)
                # Predicted class from the thresholded bubble overlap.
                pred = self.percent_overlap(radius_A, radius_B, c_dist, self.args.thresh)

                if not isinstance(pred, np.ndarray):
                    # NumPy needs a NumPy dtype here (torch.float is not accepted).
                    pred = np.asarray(pred, dtype=np.float32)
                if not isinstance(targets, np.ndarray):
                    targets = targets.cpu().numpy()

                prediction.extend(pred)
                ground_truth.extend(targets)

        test_metrics = self.metrics(prediction, ground_truth)
        if tag == "test":
            print("Test Metrics:")
            print("Precision: ", test_metrics["Precision"])
            print("Recall: ", test_metrics["Recall"])
            print("F1 Score: ", test_metrics["F1"])

            # NOTE(review): appending ('a+') concatenates one JSON object per run, so
            # the file is not a single valid JSON document after several runs.
            with open(f'../result/{self.args.dataset}/res_{self.args.version}.json', 'a+') as f:
                expt_details = {
                    "Arguments": vars(self.args),
                    "Test Metrics": test_metrics,
                }
                json.dump(expt_details, f, indent=4)

        return test_metrics
               

def preprocess_quora(args, indir, outdir, sample_divisor=3, train_frac=0.8):
    """Clean the raw Quora question-pairs CSV and write train/test splits.

    The function reads ``<indir>/questions.csv``, shuffles it, and drops rows
    with missing values or blank questions. It then keeps a sample of
    ``len(data) // sample_divisor`` rows stratified by ``is_duplicate``. The
    first ``train_frac`` of that sample is written to
    ``<outdir>/train_questions.csv`` and the remainder to
    ``<outdir>/test_questions.csv``.

    Args:
        args: parsed arguments; ``args.seed`` seeds the stratified sample.
        indir: directory containing ``questions.csv``.
        outdir: output directory (created if it does not exist).
        sample_divisor: keep ``1 / sample_divisor`` of the cleaned rows (default 3).
        train_frac: fraction of the sample used for training (default 0.8).
    """
    data = pd.read_csv(os.path.join(indir, "questions.csv"))
    # No random_state here, so this shuffle depends on the global NumPy RNG state.
    data = data.sample(frac=1).reset_index(drop=True)
    # Remove rows with NaN in any column (this also covers 'is_duplicate').
    data = data.dropna()
    # Remove rows with blank questions.
    data = data[data['question1'].str.strip() != '']
    data = data[data['question2'].str.strip() != '']
    # Remove rows whose 'is_duplicate' is an empty string.
    data = data[data['is_duplicate'].astype(str).str.strip() != '']

    # Stratified sample of 1/sample_divisor of the cleaned rows.
    data = stratified_sample(data, 'is_duplicate', len(data) // sample_divisor, random_state=args.seed)

    # Train/test split on the (already shuffled) sample.
    split = int(train_frac * len(data))
    train_data = data[:split]
    test_data = data[split:]

    # to_csv does not create missing directories.
    os.makedirs(outdir, exist_ok=True)
    train_file = os.path.join(outdir, "train_questions.csv")
    test_file = os.path.join(outdir, "test_questions.csv")
    train_data.to_csv(train_file, index=False)
    test_data.to_csv(test_file, index=False)
    print("Preprocessing done!")
    print("Train data shape: ", train_data.shape)
    print("Test data shape: ", test_data.shape)
    print("Train data saved at: ", train_file)
    print("Test data saved at: ", test_file)
    
def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    CUDA generators on all devices are seeded only when CUDA is available;
    cuDNN autotuning (``benchmark``) is turned off.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def stratified_sample(df: pd.DataFrame, label_col: str, sample_size: int, random_state=None) -> pd.DataFrame:
    """
    Draw ``sample_size`` rows from ``df`` preserving the class proportions of ``label_col``.

    Parameters:
    - df: The original DataFrame.
    - label_col: The name of the column containing class labels (0 or 1).
    - sample_size: Total number of rows to sample.
    - random_state: Random seed for reproducibility (used for every class draw
      and for the final shuffle).

    Returns:
    - A shuffled stratified sample with a fresh 0..n-1 index.

    Raises:
    - ValueError: if ``sample_size`` exceeds the number of rows in ``df``.
    """
    n_rows = len(df)
    if n_rows < sample_size:
        raise ValueError("Sample size cannot be greater than the number of rows in the dataframe.")

    # Per-class quota proportional to class frequency (value_counts lists the most
    # frequent class first).
    class_share = df[label_col].value_counts(normalize=True)
    quota = (class_share * sample_size).round().astype(int)

    # Rounding can leave the quotas a few rows off; absorb that in the first
    # (most frequent) class so the total is exactly sample_size.
    shortfall = sample_size - quota.sum()
    if shortfall:
        quota.iloc[0] += shortfall

    parts = []
    for label, count in quota.items():
        subset = df[df[label_col] == label]
        parts.append(subset.sample(n=count, random_state=random_state))
    combined = pd.concat(parts)

    # Shuffle so the classes are interleaved, then renumber the rows.
    shuffled = combined.sample(frac=1, random_state=random_state)
    return shuffled.reset_index(drop=True)

def main():
    """Command-line entry point: parse arguments, then train and evaluate the model.

    Creates ``../data/<dataset>/processed`` and ``../result/<dataset>`` if
    needed. When the processed train/test CSVs are missing, it builds them with
    ``preprocess_quora`` before training.
    """

    def _str2bool(value):
        # argparse's ``type=bool`` maps every non-empty string (even "False") to True,
        # so parse the usual spellings explicitly.
        if isinstance(value, bool):
            return value
        normalized = str(value).strip().lower()
        if normalized in {"1", "true", "t", "yes", "y"}:
            return True
        if normalized in {"0", "false", "f", "no", "n"}:
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset", type=str, default="quora", help="dataset")
    ## Model parameters
    parser.add_argument("--pre_train", type=str, default="bert", help="Pre_trained model")
    parser.add_argument(
        "--hidden", type=int, default=64, help="dimension of hidden layers in MLP"
    )
    parser.add_argument(
        "--embed_size", type=int, default=6, help="dimension of bubble embeddings"
    )
    parser.add_argument("--dropout", type=float, default=0.05, help="dropout")
    # phi is compared against bubble radii in the regularizer, not volumes.
    parser.add_argument("--phi", type=float, default=0.05, help="minimum radius of bubble")
    parser.add_argument("--probwt", type=float, default=1.0, help="weight of prob loss")
    parser.add_argument(
        "--regularwt", type=float, default=1.0, help="weight of regularization loss"
    )
    parser.add_argument("--containwt", type=float, default=1.0, help="Weight for containment loss")
    parser.add_argument("--thresh", type=float, default=0.5, help="Minimum percent of overlap")

    ## Training hyper-parameters
    parser.add_argument("--expID", type=int, default=0, help="-th of experiments")
    parser.add_argument("--epochs", type=int, default=30, help="training epochs")
    parser.add_argument("--batch_size", type=int, default=512, help="training batch size")
    parser.add_argument(
        "--lr", type=float, default=2e-5, help="learning rate for pre-trained model"
    )
    parser.add_argument(
        "--lr_projection",
        type=float,
        default=1e-3,
        help="learning rate for projection layers",
    )
    parser.add_argument("--eps", type=float, default=1e-8, help="adamw_epsilon")
    parser.add_argument("--optim", type=str, default="adamw", help="Optimizer")
    parser.add_argument("--version", type=str, default="spherex", help="version of the model")

    ## Others
    parser.add_argument("--cuda", type=_str2bool, default=True, help="use cuda for training")
    parser.add_argument("--gpu_id", type=int, default=0, help="which gpu")
    parser.add_argument("--seed", type=int, default=42, help="Seed for random generator")

    args = parser.parse_args()
    # Fall back to CPU when CUDA is requested but not available.
    args.cuda = bool(torch.cuda.is_available() and args.cuda)
    if args.cuda:
        torch.cuda.set_device(args.gpu_id)
    start_time = time.time()
    print("Start time at : ")
    print_local_time()

    print("Arguments: ", args)

    set_seed(args.seed)
    indir = f"../data/{args.dataset}/"
    outdir = f"../data/{args.dataset}/processed"
    os.makedirs(outdir, exist_ok=True)

    resdir = f"../result/{args.dataset}"
    os.makedirs(resdir, exist_ok=True)

    # The experiment reads these files; build them if a previous run has not.
    processed_files = (
        os.path.join(outdir, "train_questions.csv"),
        os.path.join(outdir, "test_questions.csv"),
    )
    if not all(os.path.exists(path) for path in processed_files):
        preprocess_quora(args, indir, outdir)

    exp = LabelClassfnExp(args)
    exp.train()
    exp.predict(tag="test")

    print("Time used :{:.01f}s".format(time.time() - start_time))
    print("End time at : ")
    print_local_time()
    print("************END***************")


if __name__ == "__main__":
    main()
    